import torch
import torch.nn as nn
from torchvision.transforms import Normalize

# 如果是三通道图片，则算出来的mean、std有三个

torch.random.manual_seed(42)
img = torch.randint(0, 256, (3, 1, 3, 3), dtype=torch.float32)
print(type(img))
print("原始图像数据:\n", img)
print("-"*50)


print("方法一，计算每张图片的均值和标准差，再求平均作为整体的均值和标准差：")
def cal_mean_std(ds):
    mean = 0.
    std = 0.
    for img in ds: # 遍历这3张图片，每张图片img.shape=[1,28,28]
        mean += img.mean(dim=(1, 2))
        std += img.std(dim=(1, 2))
    mean /= len(ds)
    std /= len(ds)
    return mean, std

print(cal_mean_std(img))
print("-"*100)


print("方法二，直接把所有图片的像素点一起计算，作为整体的均值和标准差：")
mean = img.mean(dim=(0,2,3))
std = img.std(dim=(0,2,3))
print("原始图像均值:", mean)
print(mean.shape)
print("原始图像标准差:", std)
print("-"*100)

"""
为什么两个方法算的std有区别：
方法一是算了每张图自己的std，再求平均。方法二是把3*1*3*3=27个像素点一起求了
"""

# 创建一个Normalize变换，使用计算得到的均值和标准差
normalize_transform = nn.Sequential(
    #Normalize(mean=mean.tolist(), std=std.tolist())    # 不用这个
    Normalize(mean=cal_mean_std(img)[0].tolist(), std=cal_mean_std(img)[1].tolist()) # 用整体的均值和标准差，保证图片都是同一分布
)

# 应用Normalize变换到图像数据上
normalized_img = normalize_transform(img)

# 打印标准化后的图像数据
print("标准化后的图像数据:\n", normalized_img)
print(cal_mean_std(normalized_img))
